# -*- coding: utf-8 -*-
"""
Created on Mon Apr 20 21:16:52 2020

@author: xiang_yaobing
"""
import numpy as np
#import matplotlib.pyplot as plt
import random
from scipy.stats import gaussian_kde
import os
#from astropy.io import fits
#from pandas import  DataFrame
import cv2
#设置模拟数据超参数
def mkdir(path):
    """Create directory *path* (and any missing parents) if nothing exists there.

    A status banner is printed in both cases; nothing is returned.
    """
    if os.path.exists(path):
        print("---  There is this folder(it had already been build)!  ---")
        return
    # makedirs also creates every missing intermediate directory.
    os.makedirs(path)
    print("---  new folder...  ---")
    print("---  OK  ---")

def generate_cluster(cluster_stars_num):
    """Generate simulated cluster-member stars (label 0).

    Each coordinate axis (raj first, then dej) is drawn from its own normal
    distribution. The centre is uniform in [-0.02, 0.02] and the spread is
    uniform in [0.01, 0.055].

    Returns an array of shape (cluster_stars_num, 3) with columns
    (raj, dej, label), where every label is 0.
    """
    columns = []
    for _ in range(2):  # first pass -> raj, second pass -> dej
        centre = random.uniform(-0.02, 0.02)
        spread = random.uniform(0.01, 0.055)
        columns.append(
            np.random.normal(loc=centre, scale=spread, size=[cluster_stars_num])
        )
    # TODO(review): is the cluster / non-cluster parameter ratio appropriate?
    columns.append(np.zeros([cluster_stars_num]))
    return np.column_stack(columns)

    
def generate_nocluster(nocluster_stars_num):
    """Generate field (non-cluster) stars scattered over a disc, with label 1.

    Angles are uniform in [0, 2*pi). The squared distance from the origin is
    uniform in [0, 0.2), so the points are spread uniformly by area over a
    disc of radius sqrt(0.2).

    Returns an array of shape (nocluster_stars_num, 3) with columns
    (raj, dej, label), where every label is 1.
    """
    shape = [nocluster_stars_num]
    r_sq_max = 0.2  # upper bound of the *squared* radius
    angle = np.random.uniform(0, 2 * np.pi, size=shape)
    r_sq = np.random.uniform(0, r_sq_max, size=shape)
    # Square root of a uniform variate -> uniform density per unit area.
    rho = r_sq ** 0.5
    labels = np.ones(shape)
    return np.column_stack((np.sin(angle) * rho, np.cos(angle) * rho, labels))

def data_predeal(cluster, nocluster):
    """Stack cluster and field stars row-wise and return the rows in random order."""
    stacked = np.concatenate((cluster, nocluster))
    # permutation() returns a shuffled copy; the inputs are left untouched.
    return np.random.permutation(stacked)

def data_wash(data_deal, fa):
    """Drop stars lying within ``fa`` of the edges of the sample's bounding box.

    The bounding box is computed from columns 0 (raj) and 1 (dej) of the
    *unfiltered* input. A row is kept only if both coordinates lie strictly
    inside the box shrunk by ``fa`` on every side.

    Parameters
    ----------
    data_deal : ndarray, shape (n, k) with k >= 2
        Star table; column 0 is raj and column 1 is dej. Any further columns
        (e.g. the label) are carried through unchanged.
    fa : float
        Margin removed from each side of the bounding box.

    Returns
    -------
    data_deal : ndarray
        The retained rows, in their original order.
    loc_value : list
        Original bounds ``[raj_min, raj_max, dej_min, dej_max]``.
    loc_value_s : list
        Shrunk bounds ``[raj_min+fa, raj_max-fa, dej_min+fa, dej_max-fa]``.

    Raises
    ------
    ValueError
        If ``data_deal`` has no rows (raised by ``np.min``/``np.max``).
    """
    raj = data_deal[:, 0]
    dej = data_deal[:, 1]
    raj_min, raj_max = np.min(raj), np.max(raj)
    dej_min, dej_max = np.min(dej), np.max(dej)
    loc_value = [raj_min, raj_max, dej_min, dej_max]
    loc_value_s = [raj_min + fa, raj_max - fa, dej_min + fa, dej_max - fa]
    # A single combined mask equals the four sequential filters, because
    # every threshold is derived from the unfiltered data.
    keep = (
        (raj > loc_value_s[0]) & (raj < loc_value_s[1])
        & (dej > loc_value_s[2]) & (dej < loc_value_s[3])
    )
    return data_deal[keep], loc_value, loc_value_s

def generate_main(control=1):
    """Build one shuffled simulated field of cluster stars plus field stars.

    ``control`` scales the cluster size: the number of cluster stars is
    ``int(U(70, 300) * control)``, so ``control=0`` yields no cluster stars.
    The number of field stars is ``int(U(2000, 4500))`` regardless of
    ``control``.

    Returns the shuffled star table produced by ``data_predeal``.
    """
    n_cluster = int(random.uniform(70, 300) * control)
    n_field = int(random.uniform(2000, 4500))
    # Arguments evaluate left to right: cluster stars first, then field stars.
    return data_predeal(generate_cluster(n_cluster), generate_nocluster(n_field))

def gauss_kde_location(mix_stars, loc_value, grid_size=60):
    """Evaluate a 2-D Gaussian KDE of star positions on a regular grid.

    Parameters
    ----------
    mix_stars : ndarray, shape (n, k) with k >= 2
        Star table; columns 0 (raj) and 1 (dej) are used.
    loc_value : sequence of 4 floats
        Grid bounds ``[x_min, x_max, y_min, y_max]``.
    grid_size : int, optional
        Number of grid points along each axis (default 60, the previously
        hard-coded resolution).

    Returns
    -------
    z : ndarray, shape (grid_size * grid_size,)
        KDE density at each grid node, flattened in ``meshgrid`` order.
    xgrid : ndarray, shape (grid_size, grid_size)
        x-coordinates of the grid nodes. Callers use ``xgrid.shape`` to
        reshape ``z`` back into an image.

    Notes
    -----
    ``scipy.stats.gaussian_kde`` raises if the points do not span two
    dimensions (e.g. too few or collinear points).
    """
    loc = np.array([mix_stars[:, 0], mix_stars[:, 1]])
    kde = gaussian_kde(loc)
    xgrid, ygrid = np.meshgrid(
        np.linspace(loc_value[0], loc_value[1], grid_size),
        np.linspace(loc_value[2], loc_value[3], grid_size),
    )
    z = kde.evaluate(np.vstack([xgrid.ravel(), ygrid.ravel()]))
    return z, xgrid
       
if __name__ == '__main__':
    # Build the output directory portably. On Windows this is the same
    # '..\\trainset\\simulateField' path as before; on POSIX it no longer
    # creates a single directory with backslashes in its name.
    filePath = os.path.join('..', 'trainset', 'simulateField')
    mkdir(filePath)
    num = 10000  # number of simulated field images to write
    for i in range(num):
        print(i)
        loc_save_path = os.path.join(filePath, 'Field' + str(i) + '.jpg')
        # control=0 -> no cluster stars; the field contains only label-1 stars.
        data = generate_main(control=0)
        # Two edge-trimmed views of the same field (0.05 and 0.1 margins).
        # The original bounds are identical for both calls, so the second
        # copy is discarded.
        datas, loc_value, loc_value_s = data_wash(data, 0.05)
        datass, _, loc_value_ss = data_wash(data, 0.1)

        # Channel c holds the inverted, min-max-normalised KDE of view c.
        views = (
            (data, loc_value),
            (datas, loc_value_s),
            (datass, loc_value_ss),
        )
        multiKde = np.zeros((60, 60, 3))  # 60 matches gauss_kde_location's grid
        for channel, (stars, bounds) in enumerate(views):
            z, grid = gauss_kde_location(stars, bounds)
            z = z.reshape(grid.shape)
            z = ((z - z.min()) / (z.max() - z.min())) * 255
            # np.int was removed in NumPy 1.24; the builtin int is the same type.
            multiKde[:, :, channel] = 255 - z.astype(int)
        cv2.imwrite(loc_save_path, multiKde)
